package com.alibaba.alink.operator.common.statistics.basicstatistic;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.utils.TableUtil;
import org.apache.flink.types.Row;

import java.util.ArrayList;

/**
 * Summary statistics of the columns of a table.
 *
 * <p>Statistics are looked up by column name. Two different indexing schemes are used internally:
 * <ul>
 *   <li>{@link #numMissingValue} is indexed by the position of the column in {@link #colNames};</li>
 *   <li>{@link #sum}, {@link #squareSum}, {@link #min}, {@link #max} and {@link #normL1} are indexed
 *       by the position of the column inside {@link #numericalColIndices}.</li>
 * </ul>
 *
 * <p>For a column that exists but is not listed in {@link #numericalColIndices}, the numerical
 * statistics are reported as {@code Double.NaN}. Asking for a column that is not in
 * {@link #colNames} raises a {@link RuntimeException}.
 */
public class TableSummary extends BaseSummary {

    /**
     * Names of the columns covered by this summary.
     */
    String[] colNames;

    /**
     * Number of missing values per column, indexed by position in {@link #colNames}.
     */
    DenseVector numMissingValue;

    /**
     * sum_i = sum(x_i), indexed by position in {@link #numericalColIndices}.
     */
    DenseVector sum;

    /**
     * squareSum_i = sum(x_i * x_i), indexed by position in {@link #numericalColIndices}.
     */
    DenseVector squareSum;

    /**
     * min_i = min(x_i), indexed by position in {@link #numericalColIndices}.
     */
    DenseVector min;

    /**
     * max_i = max(x_i), indexed by position in {@link #numericalColIndices}.
     */
    DenseVector max;

    /**
     * normL1_i = sum(|x_i|), indexed by position in {@link #numericalColIndices}.
     */
    DenseVector normL1;

    /**
     * Indices (into {@link #colNames}) of the columns whose type is numerical.
     */
    int[] numericalColIndices;

    /**
     * Package-private: instances are meant to be produced by the summarizer, which fills the fields.
     */
    TableSummary() {

    }

    /**
     * Renders one line per column with every statistic this class exposes.
     *
     * @return the table produced by {@code TableUtil.format} for the collected rows.
     */
    @Override
    public String toString() {
        String[] outColNames = new String[]{"colName", "count", "numMissingValue", "numValidValue",
            "sum", "mean", "variance", "standardDeviation", "min", "max", "normL1", "normL2"};

        // Typed list instead of a raw ArrayList; one row is produced per column.
        ArrayList<Row> data = new ArrayList<>(colNames.length);

        for (String colName : colNames) {
            Row row = new Row(outColNames.length);

            row.setField(0, colName);
            row.setField(1, count);
            row.setField(2, numMissingValue(colName));
            row.setField(3, numValidValue(colName));
            row.setField(4, sum(colName));
            row.setField(5, mean(colName));
            row.setField(6, variance(colName));
            row.setField(7, standardDeviation(colName));
            row.setField(8, min(colName));
            row.setField(9, max(colName));
            row.setField(10, normL1(colName));
            row.setField(11, normL2(colName));

            data.add(row);
        }

        return TableUtil.format(outColNames, data);
    }

    /**
     * Returns a copy of the names of the summarized columns.
     *
     * @return a clone of the internal array, so callers cannot modify the summary.
     */
    public String[] getColNames() {
        return this.colNames.clone();
    }

    /**
     * Returns the sum of the given column.
     *
     * @param colName name of the column.
     * @return the sum, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double sum(String colName) {
        return numericalStat(sum, colName);
    }

    /**
     * Returns the mean of the given column, computed over its non-missing values.
     *
     * @param colName name of the column.
     * @return the mean, {@code 0} if the column has no valid value, or {@code Double.NaN} if the
     * column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double mean(String colName) {
        int idx = findIdx(colName);
        if (idx < 0) {
            return Double.NaN;
        }
        // NOTE: numMissingValue is read with the numerical-slot index here, whereas
        // numMissingValue(String) reads it with the colNames index. Kept as in the original code.
        double validCount = count - numMissingValue.get(idx);
        if (0 == validCount) {
            return 0;
        }
        return sum.get(idx) / validCount;
    }

    /**
     * Returns the sample variance (denominator {@code n - 1}) of the given column.
     *
     * @param colName name of the column.
     * @return the variance, {@code 0} if the column has zero or one valid value, or
     * {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double variance(String colName) {
        int idx = findIdx(colName);
        if (idx < 0) {
            return Double.NaN;
        }
        // NOTE: same indexing of numMissingValue as in mean(String); kept as in the original code.
        double validCount = count - numMissingValue.get(idx);
        if (0 == validCount || 1 == validCount) {
            return 0;
        }
        double total = sum.get(idx);
        double centeredSquareSum = squareSum.get(idx) - total * total / validCount;
        // Clamp at zero: rounding errors can make the computed value slightly negative.
        return Math.max(0.0, centeredSquareSum / (validCount - 1));
    }

    /**
     * Returns the standard deviation of the given column, i.e. the square root of
     * {@link #variance(String)}.
     *
     * @param colName name of the column.
     * @return the standard deviation, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double standardDeviation(String colName) {
        return Math.sqrt(variance(colName));
    }

    /**
     * Returns the minimum of the given column.
     *
     * @param colName name of the column.
     * @return the minimum, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double min(String colName) {
        return numericalStat(min, colName);
    }

    /**
     * Returns the maximum of the given column.
     *
     * @param colName name of the column.
     * @return the maximum, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double max(String colName) {
        return numericalStat(max, colName);
    }

    /**
     * Returns the L1 norm of the given column.
     *
     * @param colName name of the column.
     * @return the L1 norm, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double normL1(String colName) {
        return numericalStat(normL1, colName);
    }

    /**
     * Returns the L2 norm of the given column, i.e. the square root of its sum of squares.
     *
     * @param colName name of the column.
     * @return the L2 norm, or {@code Double.NaN} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    public double normL2(String colName) {
        int idx = findIdx(colName);
        if (idx < 0) {
            return Double.NaN;
        }
        return Math.sqrt(squareSum.get(idx));
    }

    /**
     * Returns the number of valid (non-missing) values of the given column.
     *
     * @param colName name of the column.
     * @return the inherited {@code count} minus the number of missing values.
     * @throws RuntimeException if the column does not exist.
     */
    public double numValidValue(String colName) {
        return count - numMissingValue(colName);
    }

    /**
     * Returns the number of missing values of the given column.
     *
     * @param colName name of the column.
     * @return the number of missing values.
     * @throws RuntimeException if the column does not exist.
     */
    public double numMissingValue(String colName) {
        return numMissingValue.get(requireColIndex(colName));
    }

    /**
     * Reads the entry of a per-numerical-column statistic vector for the given column.
     *
     * @param stat    vector indexed by position in {@link #numericalColIndices}.
     * @param colName name of the column.
     * @return the stored value, or {@code Double.NaN} if the column is not numerical.
     */
    private double numericalStat(DenseVector stat, String colName) {
        int idx = findIdx(colName);
        return idx >= 0 ? stat.get(idx) : Double.NaN;
    }

    /**
     * Returns the position of the given column in {@link #colNames}.
     *
     * @param colName name of the column.
     * @return the non-negative index of the column.
     * @throws RuntimeException if the lookup reports a negative index, i.e. the column is unknown.
     */
    private int requireColIndex(String colName) {
        int idx = TableUtil.findColIndex(colNames, colName);
        if (idx < 0) {
            throw new RuntimeException(colName + " is not exist.");
        }
        return idx;
    }

    /**
     * Returns the position of the given column inside {@link #numericalColIndices}, which is the
     * index to use with the numerical statistic vectors.
     *
     * @param colName name of the column.
     * @return the numerical slot of the column, or {@code -1} if the column is not numerical.
     * @throws RuntimeException if the column does not exist.
     */
    private int findIdx(String colName) {
        return findIdx(numericalColIndices, requireColIndex(colName));
    }

    /**
     * Returns the position of {@code idx} inside {@code colIndices}.
     *
     * @param colIndices array to search.
     * @param idx        value to look for.
     * @return the first position {@code i} with {@code colIndices[i] == idx}, or {@code -1} if
     * there is none.
     */
    private static int findIdx(int[] colIndices, int idx) {
        for (int i = 0; i < colIndices.length; i++) {
            if (idx == colIndices[i]) {
                return i;
            }
        }
        return -1;
    }

}
